# -*- coding: utf-8 -*-
# The script was originally run on a supercomputing node (the sbatch file is attached).
# There are two ways to run it:
# 1. On a supercomputing node: before running, set up the conda environment
#    with the following command
# conda env create -f environment.yml
# 2. On Google Colab: first install the required Python modules by running
# pip install transformers seaborn
# Then run the code, preferably on a GPU node
# !python train_bert_ac.py

## import modules

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

import pandas as pd
import numpy as np


from tabulate import tabulate
from tqdm import trange
import random

import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

import os

# Directory that receives one checkpoint per training epoch.
# exist_ok=True replaces the exists()/makedirs() pair, which could race and
# raise FileExistsError if the directory appeared in between.
os.makedirs("./saved_model", exist_ok=True)

## Prepare training data

# Two hand-coded corpora: the first already has `text`/`category` columns,
# the second stores its text under `content`.
df = pd.read_csv("./ac_coded1.csv")
df2 = pd.read_csv("./ac_coded2.csv")

## Merge data

df = df[['text', 'category']]
df2 = df2[['content', 'category']]

# Align df2's column names with df. Only `content` needs renaming: `category`
# is already selected above, and mapping a label that is not present (such
# as "code") together with errors="raise" makes pandas raise KeyError.
df2 = df2.rename(columns={"content": "text"}, errors="raise")

# Concatenate the corpora and drop duplicate texts; because this happens
# before the split, an identical sentence cannot end up in two splits.
result = pd.concat([df, df2])
result = result.drop_duplicates("text").reset_index(drop=True)
print(result['category'].value_counts())

text = result.text.values

## Preprocess data for BERT training

## Get label dictionary for future purposes

# Map each category to a contiguous integer id, in order of first appearance.
possible_labels = result['category'].unique()
label_dict = {possible_label: index
              for index, possible_label in enumerate(possible_labels)}

# Encode the targets with label_dict. The raw category values cannot be fed
# to torch.tensor() when they are strings. Even when they are numbers, they
# would not agree with label_dict, whose inverse is used later to turn
# predicted ids back into category names.
labels = result.category.map(label_dict).values

# Load the BERT tokenizer.
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

# Longest tokenized sentence, counting the [CLS] and [SEP] special tokens.
# It is only reported; the encoding step below uses a fixed max_length.
token_lengths = (len(tokenizer.encode(sentence, add_special_tokens=True))
                 for sentence in text)
max_len = max(token_lengths, default=0)

print('Max sentence length: ', max_len)

# Tokenize all of the sentences and map the tokens to their word IDs.
# One batched tokenizer call:
#   (1) tokenizes every sentence,
#   (2) prepends the `[CLS]` token and appends the `[SEP]` token,
#   (3) maps tokens to their IDs,
#   (4) pads or truncates each sentence to exactly `max_length` tokens,
#   (5) builds attention masks that distinguish [PAD] tokens from real ones,
# and returns PyTorch tensors with one row per sentence.
# `padding='max_length'` replaces the deprecated `pad_to_max_length=True`.
encoded = tokenizer(
    list(text),
    add_special_tokens=True,
    max_length=128,
    padding='max_length',
    truncation=True,
    return_attention_mask=True,
    return_tensors='pt',
)

input_ids = encoded['input_ids']
attention_masks = encoded['attention_mask']
labels = torch.tensor(labels)


## Split training, validation, testing data

from torch.utils.data import TensorDataset, random_split

# Combine the encoded inputs and labels into a TensorDataset.
dataset = TensorDataset(input_ids, attention_masks, labels)

# The global seeds are only set further down the script, so the splits get a
# dedicated, fixed-seed generator to make the partition reproducible.
split_generator = torch.Generator().manual_seed(17)

# First an 80-20 train-test split ...
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=split_generator)

print('{:>5,} training samples'.format(train_size))
print('{:>5,} testing samples'.format(test_size))

# ... then the training part is split 75-25 into training and validation,
# roughly 60/20/20 train/validation/test overall.
train_size = int(0.75 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(
    train_dataset, [train_size, val_size], generator=split_generator)

print('{:>5,} training samples'.format(train_size))
print('{:>5,} validation samples'.format(val_size))


## Define the model

# Chinese BERT with a classification head sized to the number of categories.
# Attention maps and hidden states are not requested; the code below reads
# outputs[0] as the loss and outputs[1] as the logits.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=len(label_dict),
    output_attentions=False,
    output_hidden_states=False,
)



## Define other parameters and evaluation metrics

batch_size = 32

# Training batches are drawn in random order; validation and test batches
# are read sequentially, in a fixed order.
dataloader_train = DataLoader(
    train_dataset, sampler=RandomSampler(train_dataset), batch_size=batch_size)
dataloader_validation = DataLoader(
    val_dataset, sampler=SequentialSampler(val_dataset), batch_size=batch_size)
dataloader_test = DataLoader(
    test_dataset, sampler=SequentialSampler(test_dataset), batch_size=batch_size)


# torch.optim.AdamW is the documented replacement for the deprecated
# transformers.AdamW. weight_decay=0.0 matches the old transformers default
# (torch's own default is 0.01), so the optimisation itself is unchanged.
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=1e-5,
                              eps=1e-8,
                              weight_decay=0.0)

epochs = 10  # default 10 epochs, traditionally 2 - 4

# Linear learning-rate schedule over every optimisation step, no warm-up.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=len(dataloader_train) * epochs)



def f1_score_func(preds, labels):
    """Return the weighted F1 score of predictions against true labels.

    `preds` holds one row of class scores per example; the predicted class
    is the arg-max of each row. `labels` holds the true class ids.
    """
    predicted = np.argmax(preds, axis=1).ravel()
    actual = labels.ravel()
    return f1_score(actual, predicted, average='weighted')

def accuracy_per_class(preds, labels):
    """Print per-class accuracy as "correct/total" for every class in `labels`.

    `preds` holds one row of class scores per example (the arg-max is the
    prediction); `labels` holds integer class ids, which are turned back into
    category names through the module-level `label_dict`.
    """
    id_to_name = {idx: name for name, idx in label_dict.items()}

    predicted = np.argmax(preds, axis=1).flatten()
    actual = labels.flatten()

    for class_id in np.unique(actual):
        in_class = actual == class_id
        hits = int(np.sum(predicted[in_class] == class_id))
        print(f'Class: {id_to_name[class_id]}')
        print(f'Accuracy: {hits}/{int(np.sum(in_class))}\n')

## Manual seeding for reproducibility and moving the model to CUDA (GPU)

# NOTE: the dataset split and the model construction above run before this
# point, so they are not governed by these seeds.
seed_val = 17
for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                torch.cuda.manual_seed_all):
    seed_fn(seed_val)
torch.cuda.empty_cache()

# Use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
print(device)


## Evaluation function to be used once training is complete

def evaluate(dataloader_val):
    """Run the global `model` over a dataloader without tracking gradients.

    Each batch is an (input_ids, attention_mask, labels) triple, matching the
    TensorDataset built earlier in the script.

    Returns:
        A tuple of (mean loss per batch, concatenated logits as a NumPy
        array, concatenated true labels as a NumPy array).
    """
    model.eval()

    running_loss = 0
    logit_chunks = []
    label_chunks = []

    for ids_batch, mask_batch, label_batch in dataloader_val:
        ids_batch = ids_batch.to(device)
        mask_batch = mask_batch.to(device)
        label_batch = label_batch.to(device)

        with torch.no_grad():
            outputs = model(input_ids=ids_batch,
                            attention_mask=mask_batch,
                            labels=label_batch)

        loss, logits = outputs[0], outputs[1]
        running_loss += loss.item()

        logit_chunks.append(logits.detach().cpu().numpy())
        label_chunks.append(label_batch.cpu().numpy())

    mean_loss = running_loss / len(dataloader_val)
    return (mean_loss,
            np.concatenate(logit_chunks, axis=0),
            np.concatenate(label_chunks, axis=0))


## Train the BERT model

from tqdm import tqdm

## Training

for epoch in tqdm(range(1, epochs + 1)):

    model.train()

    loss_train_total = 0

    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch),
                        leave=False, disable=False)
    for batch in progress_bar:

        model.zero_grad()

        # Each batch is an (input_ids, attention_mask, labels) triple.
        batch = tuple(b.to(device) for b in batch)
        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'labels':         batch[2],
                  }

        outputs = model(**inputs)

        # Because `labels` is passed, outputs[0] is the batch loss.
        loss = outputs[0]
        loss_train_total += loss.item()
        loss.backward()

        # Clip the gradient norm to 1.0 before each optimiser update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        # Report the batch loss as-is. Dividing by len(batch) would be wrong:
        # `batch` is the 3-tensor tuple above, not the number of examples.
        progress_bar.set_postfix(
            {'training_loss': '{:.3f}'.format(loss.item())}, refresh=False)

    # Checkpoint the weights after every epoch.
    torch.save(model.state_dict(),
               f'./saved_model/finetuned_BERT_epoch_{epoch}.model')

    tqdm.write(f'\nEpoch {epoch}')

    loss_train_avg = loss_train_total / len(dataloader_train)
    tqdm.write(f'Training loss: {loss_train_avg}')

    val_loss, predictions, true_vals = evaluate(dataloader_validation)
    val_f1 = f1_score_func(predictions, true_vals)
    tqdm.write(f'Validation loss: {val_loss}')
    tqdm.write(f'F1 Score (Weighted): {val_f1}')
    tqdm.write("------")


## Evaluation on validation

torch.cuda.empty_cache()

# Rebuild the architecture, then load the final epoch's fine-tuned weights.
model = BertForSequenceClassification.from_pretrained("bert-base-chinese",
                                                      num_labels=len(label_dict),
                                                      output_attentions=False,
                                                      output_hidden_states=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Derive the checkpoint name from `epochs` instead of hard-coding epoch 10,
# so changing the epoch count cannot load a stale or missing file.
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine.
final_checkpoint = f'./saved_model/finetuned_BERT_epoch_{epochs}.model'
model.load_state_dict(torch.load(final_checkpoint,
                                 map_location=torch.device('cpu')))

_, predictions, true_vals = evaluate(dataloader_validation)
accuracy_per_class(predictions, true_vals)


## Write out results

label_dict_inverse = {v: k for k, v in label_dict.items()}

preds_flat = np.argmax(predictions, axis=1).flatten()
labels_flat = true_vals.flatten()

# One row per validation example: true/predicted ids ("l"/"p"), their
# category names, and a 0/1 correctness flag. Named `results_df` rather than
# `all` so the built-in all() is not shadowed.
results_df = pd.DataFrame({"l": labels_flat, "p": preds_flat})
print(results_df.shape)
results_df["label"] = results_df["l"].map(label_dict_inverse)
results_df["pred"] = results_df["p"].map(label_dict_inverse)
results_df["accuracy"] = (results_df["l"] == results_df["p"]).astype(int)
# To export: results_df.to_excel("./saved_model/results.xlsx", index=False)


## Confusion matrix

sns.set(font="IPAexGothic")

# Pass every class id explicitly. Without `labels=`, confusion_matrix keeps
# only the ids seen in y_true or y_pred, so a class missing from this split
# would shrink the matrix and no longer match the DataFrame index below.
class_ids = sorted(label_dict_inverse)
cm = confusion_matrix(results_df["l"], results_df["p"], labels=class_ids)

df_cm = pd.DataFrame(cm, index=class_ids, columns=class_ids)
plt.figure(figsize=(10, 7))
sns.heatmap(df_cm, fmt='g', annot=True, cmap="Blues")
plt.savefig('heatmap_validation.png', dpi=300)
plt.close()  # release the figure; the test section opens its own

## Evaluation on testing

torch.cuda.empty_cache()

# Rebuild the architecture, then load the final epoch's fine-tuned weights.
model = BertForSequenceClassification.from_pretrained("bert-base-chinese",
                                                      num_labels=len(label_dict),
                                                      output_attentions=False,
                                                      output_hidden_states=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Derive the checkpoint name from `epochs` instead of hard-coding epoch 10.
final_checkpoint = f'./saved_model/finetuned_BERT_epoch_{epochs}.model'
model.load_state_dict(torch.load(final_checkpoint,
                                 map_location=torch.device('cpu')))

_, predictions, true_vals = evaluate(dataloader_test)
accuracy_per_class(predictions, true_vals)

label_dict_inverse = {v: k for k, v in label_dict.items()}

preds_flat = np.argmax(predictions, axis=1).flatten()
labels_flat = true_vals.flatten()

# One row per test example; `results_df` avoids shadowing the built-in all().
results_df = pd.DataFrame({"l": labels_flat, "p": preds_flat})
print(results_df.shape)
results_df["label"] = results_df["l"].map(label_dict_inverse)
results_df["pred"] = results_df["p"].map(label_dict_inverse)
results_df["accuracy"] = (results_df["l"] == results_df["p"]).astype(int)


## Confusion matrix


sns.set(font="IPAexGothic")

# Pass every class id explicitly so a class absent from the test split still
# gets a row/column and the matrix matches the DataFrame index below.
class_ids = sorted(label_dict_inverse)
cm = confusion_matrix(results_df["l"], results_df["p"], labels=class_ids)

df_cm = pd.DataFrame(cm, index=class_ids, columns=class_ids)
plt.figure(figsize=(10, 7))
sns.heatmap(df_cm, fmt='g', annot=True, cmap="Blues")
plt.savefig('heatmap_test.png', dpi=300)
plt.close()


